// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019-2020 Guillaume Becquin
// Copyright 2020 Maarten van Gompel
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # Sequence classification pipeline (e.g. Sentiment Analysis)
//! Generic sequence classification pipeline supporting multiple architectures (e.g. BERT, DistilBERT, RoBERTa, ALBERT, DeBERTa, BART)
//!
//! ```no_run
//! use rust_bert::pipelines::sequence_classification::SequenceClassificationConfig;
//! use rust_bert::resources::{RemoteResource};
//! use rust_bert::distilbert::{DistilBertModelResources, DistilBertVocabResources, DistilBertConfigResources};
//! use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
//! use rust_bert::pipelines::common::ModelType;
//! # fn main() -> anyhow::Result<()> {
//!
//! //Load a configuration
//! use rust_bert::pipelines::common::ModelResource;
//! let config = SequenceClassificationConfig::new(ModelType::DistilBert,
//!    ModelResource::Torch(Box::new(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2))),
//!    RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2),
//!    RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2),
//!    None, // Merge resources
//!    true, //lowercase
//!    None, //strip_accents
//!    None, //add_prefix_space
//! );
//!
//! //Create the model
//! let sequence_classification_model = SequenceClassificationModel::new(config)?;
//!
//! let input = [
//!     "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
//!     "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
//!     "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
//! ];
//! let output = sequence_classification_model.predict(&input);
//! # Ok(())
//! # }
//! ```
//! (Example courtesy of [IMDb](http://www.imdb.com))
//!
//! Output: \
//! ```no_run
//! # use rust_bert::pipelines::sequence_classification::Label;
//! let output =
//! [
//!    Label { text: String::from("POSITIVE"), score: 0.9986, id: 1, sentence: 0},
//!    Label { text: String::from("NEGATIVE"), score: 0.9985, id: 0, sentence: 1},
//!    Label { text: String::from("POSITIVE"), score: 0.9988, id: 1, sentence: 2},
//! ]
//! # ;
//! ```
use crate::albert::AlbertForSequenceClassification;
use crate::bart::BartForSequenceClassification;
use crate::bert::BertForSequenceClassification;
use crate::common::error::RustBertError;
use crate::deberta::DebertaForSequenceClassification;
use crate::distilbert::DistilBertModelClassifier;
use crate::fnet::FNetForSequenceClassification;
use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification;
use crate::pipelines::common::{
    cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
};
use crate::reformer::ReformerForSequenceClassification;
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForSequenceClassification;
use crate::xlnet::XLNetForSequenceClassification;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tch::nn::VarStore;
use tch::{no_grad, Device, Kind, Tensor};

use crate::deberta_v2::DebertaV2ForSequenceClassification;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
#[cfg(feature = "remote")]
use crate::{
    distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
    resources::RemoteResource,
};

#[derive(Debug, Serialize, Deserialize, Clone)]
/// # Label generated by a `SequenceClassificationModel`
///
/// Describes one predicted class for one input sentence. `predict` produces a single `Label`
/// per input sentence, while `predict_multilabel` may produce several per sentence.
pub struct Label {
    /// Label String representation, looked up from the model configuration's label mapping
    pub text: String,
    /// Confidence score: softmax probability for `predict`, sigmoid output for `predict_multilabel`
    pub score: f64,
    /// Label ID (index of the class in the model output)
    pub id: i64,
    /// Index of the input sentence this label refers to
    // `serde(default)` lets serialized labels without this field deserialize (to 0).
    #[serde(default)]
    pub sentence: usize,
}

/// # Configuration for SequenceClassificationModel
/// Contains information regarding the model to load and device to place the model on.
///
/// Build it with `SequenceClassificationConfig::new`, or use `Default::default()` (requires the
/// `remote` feature) for an English DistilBERT sentiment model fine-tuned on SST-2.
pub struct SequenceClassificationConfig {
    /// Model type
    pub model_type: ModelType,
    /// Model weights resource (default: DistilBERT fine-tuned on SST-2)
    pub model_resource: ModelResource,
    /// Config resource (default: DistilBERT fine-tuned on SST-2)
    pub config_resource: Box<dyn ResourceProvider + Send>,
    /// Vocab resource (default: DistilBERT fine-tuned on SST-2)
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    /// Merges resource (default: None)
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Automatically lower case all input upon tokenization (assumes a lower-cased model)
    pub lower_case: bool,
    /// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
    pub strip_accents: Option<bool>,
    /// Flag indicating if the tokenizer should add a white space before each tokenized input (needed for some Roberta models)
    pub add_prefix_space: Option<bool>,
    /// Device to place the model on (default: CUDA/GPU when available)
    pub device: Device,
    /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
    pub kind: Option<Kind>,
}

impl SequenceClassificationConfig {
    /// Instantiate a new sequence classification configuration of the supplied type.
    ///
    /// The returned configuration places the model on CUDA when available (CPU otherwise) and
    /// leaves the weights precision (`kind`) unset. Both fields are public and can be adjusted
    /// after construction.
    ///
    /// # Arguments
    ///
    /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
    /// * `model_resource` - The `ModelResource` pointing to the model weights to load (e.g. model.ot for Torch models)
    /// * `config_resource` - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
    /// * `vocab_resource` - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
    /// * `merges_resource` - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt),
    ///   needed only for BPE-based tokenizers such as Roberta. It must be the same resource type as `vocab_resource`.
    /// * `lower_case` - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
    /// * `strip_accents` - Optional flag indicating whether the tokenizer should strip accents (only used for BERT / ALBERT models)
    /// * `add_prefix_space` - Optional flag indicating whether the tokenizer should add a white space before each input
    ///   (needed for some Roberta models)
    pub fn new<RC, RV>(
        model_type: ModelType,
        model_resource: ModelResource,
        config_resource: RC,
        vocab_resource: RV,
        merges_resource: Option<RV>,
        lower_case: bool,
        strip_accents: impl Into<Option<bool>>,
        add_prefix_space: impl Into<Option<bool>>,
    ) -> SequenceClassificationConfig
    where
        RC: ResourceProvider + Send + 'static,
        RV: ResourceProvider + Send + 'static,
    {
        SequenceClassificationConfig {
            model_type,
            model_resource,
            config_resource: Box::new(config_resource),
            vocab_resource: Box::new(vocab_resource),
            merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
            lower_case,
            strip_accents: strip_accents.into(),
            add_prefix_space: add_prefix_space.into(),
            device: Device::cuda_if_available(),
            kind: None,
        }
    }
}

#[cfg(feature = "remote")]
impl Default for SequenceClassificationConfig {
    /// Provides the default sentiment analysis configuration: an English DistilBERT model
    /// fine-tuned on SST-2, with a lower-casing tokenizer and no merges file.
    fn default() -> SequenceClassificationConfig {
        let weights = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
            DistilBertModelResources::DISTIL_BERT_SST2,
        )));
        let model_config =
            RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2);
        let vocabulary =
            RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2);
        let lower_case = true;

        SequenceClassificationConfig::new(
            ModelType::DistilBert,
            weights,
            model_config,
            vocabulary,
            None,
            lower_case,
            None,
            None,
        )
    }
}

#[allow(clippy::large_enum_variant)]
/// # Abstraction that holds one particular sequence classification model, for any of the supported models
///
/// Each variant stores its architecture-specific model inline (unboxed), and the variants differ
/// in size, hence the allowed `clippy::large_enum_variant` lint. Instances are built with
/// `SequenceClassificationOption::new` and executed with `forward_t`.
pub enum SequenceClassificationOption {
    /// Bert for Sequence Classification
    Bert(BertForSequenceClassification),
    /// DeBERTa for Sequence Classification
    Deberta(DebertaForSequenceClassification),
    /// DeBERTa V2 for Sequence Classification
    DebertaV2(DebertaV2ForSequenceClassification),
    /// DistilBert for Sequence Classification
    DistilBert(DistilBertModelClassifier),
    /// MobileBert for Sequence Classification
    MobileBert(MobileBertForSequenceClassification),
    /// Roberta for Sequence Classification
    Roberta(RobertaForSequenceClassification),
    /// XLMRoberta for Sequence Classification (reuses the Roberta implementation)
    XLMRoberta(RobertaForSequenceClassification),
    /// Albert for Sequence Classification
    Albert(AlbertForSequenceClassification),
    /// XLNet for Sequence Classification
    XLNet(XLNetForSequenceClassification),
    /// Bart for Sequence Classification
    Bart(BartForSequenceClassification),
    /// Reformer for Sequence Classification
    Reformer(ReformerForSequenceClassification),
    /// Longformer for Sequence Classification
    Longformer(LongformerForSequenceClassification),
    /// FNet for Sequence Classification
    FNet(FNetForSequenceClassification),
    /// ONNX Model for Sequence Classification (only available with the `onnx` feature)
    #[cfg(feature = "onnx")]
    ONNX(ONNXEncoder),
}

impl SequenceClassificationOption {
    /// Instantiate a new sequence classification model of the supplied type.
    ///
    /// # Arguments
    ///
    /// * `SequenceClassificationConfig` - Sequence classification pipeline configuration. The type of model created will be inferred from the
    ///   `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
    pub fn new(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
        match config.model_resource {
            ModelResource::Torch(_) => Self::new_torch(config),
            #[cfg(feature = "onnx")]
            ModelResource::ONNX(_) => Self::new_onnx(config),
        }
    }

    fn new_torch(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
        let device = config.device;
        let weights_path = config.model_resource.get_torch_local_path()?;
        let mut var_store = VarStore::new(device);
        let model_config =
            &ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
        let model_type = config.model_type;
        let model = match model_type {
            ModelType::Bert => {
                if let ConfigOption::Bert(config) = model_config {
                    Ok(Self::Bert(
                        BertForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a BertConfig for Bert!".to_string(),
                    ))
                }
            }
            ModelType::Deberta => {
                if let ConfigOption::Deberta(config) = model_config {
                    Ok(Self::Deberta(
                        DebertaForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a DebertaConfig for DeBERTa!".to_string(),
                    ))
                }
            }
            ModelType::DebertaV2 => {
                if let ConfigOption::DebertaV2(config) = model_config {
                    Ok(Self::DebertaV2(
                        DebertaV2ForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a DebertaV2Config for DeBERTa V2!".to_string(),
                    ))
                }
            }
            ModelType::DistilBert => {
                if let ConfigOption::DistilBert(config) = model_config {
                    Ok(Self::DistilBert(
                        DistilBertModelClassifier::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a DistilBertConfig for DistilBert!".to_string(),
                    ))
                }
            }
            ModelType::MobileBert => {
                if let ConfigOption::MobileBert(config) = model_config {
                    Ok(Self::MobileBert(
                        MobileBertForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a MobileBertConfig for MobileBert!".to_string(),
                    ))
                }
            }
            ModelType::Roberta => {
                if let ConfigOption::Roberta(config) = model_config {
                    Ok(Self::Roberta(
                        RobertaForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a RobertaConfig for Roberta!".to_string(),
                    ))
                }
            }
            ModelType::XLMRoberta => {
                if let ConfigOption::Roberta(config) = model_config {
                    Ok(Self::XLMRoberta(
                        RobertaForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a RobertaConfig for Roberta!".to_string(),
                    ))
                }
            }
            ModelType::Albert => {
                if let ConfigOption::Albert(config) = model_config {
                    Ok(Self::Albert(
                        AlbertForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply an AlbertConfig for Albert!".to_string(),
                    ))
                }
            }
            ModelType::XLNet => {
                if let ConfigOption::XLNet(config) = model_config {
                    Ok(Self::XLNet(
                        XLNetForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply an XLNetConfig for XLNet!".to_string(),
                    ))
                }
            }
            ModelType::Bart => {
                if let ConfigOption::Bart(config) = model_config {
                    Ok(Self::Bart(
                        BartForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a BertConfig for Bert!".to_string(),
                    ))
                }
            }
            ModelType::Reformer => {
                if let ConfigOption::Reformer(config) = model_config {
                    Ok(Self::Reformer(
                        ReformerForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a ReformerConfig for Reformer!".to_string(),
                    ))
                }
            }
            ModelType::Longformer => {
                if let ConfigOption::Longformer(config) = model_config {
                    Ok(Self::Longformer(
                        LongformerForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a LongformerConfig for Longformer!".to_string(),
                    ))
                }
            }
            ModelType::FNet => {
                if let ConfigOption::FNet(config) = model_config {
                    Ok(Self::FNet(
                        FNetForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a FNetConfig for FNet!".to_string(),
                    ))
                }
            }
            #[cfg(feature = "onnx")]
            ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
                "A `ModelType::ONNX` ModelType was provided in the configuration with `ModelResources::TORCH`, these are incompatible".to_string(),
            )),
            _ => Err(RustBertError::InvalidConfigurationError(format!(
                "Sequence Classification not implemented for {model_type:?}!",
            ))),
        }?;
        var_store.load(weights_path)?;
        cast_var_store(&mut var_store, config.kind, device);
        Ok(model)
    }

    #[cfg(feature = "onnx")]
    pub fn new_onnx(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
        let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
        let environment = onnx_config.get_environment()?;
        let encoder_file = config
            .model_resource
            .get_onnx_local_paths()?
            .encoder_path
            .ok_or(RustBertError::InvalidConfigurationError(
                "An encoder file must be provided for sequence classification ONNX models."
                    .to_string(),
            ))?;

        Ok(Self::ONNX(ONNXEncoder::new(
            encoder_file,
            &environment,
            &onnx_config,
        )?))
    }

    /// Returns the `ModelType` for this SequenceClassificationOption
    pub fn model_type(&self) -> ModelType {
        match *self {
            Self::Bert(_) => ModelType::Bert,
            Self::Deberta(_) => ModelType::Deberta,
            Self::DebertaV2(_) => ModelType::DebertaV2,
            Self::Roberta(_) => ModelType::Roberta,
            Self::XLMRoberta(_) => ModelType::Roberta,
            Self::DistilBert(_) => ModelType::DistilBert,
            Self::MobileBert(_) => ModelType::MobileBert,
            Self::Albert(_) => ModelType::Albert,
            Self::XLNet(_) => ModelType::XLNet,
            Self::Bart(_) => ModelType::Bart,
            Self::Reformer(_) => ModelType::Reformer,
            Self::Longformer(_) => ModelType::Longformer,
            Self::FNet(_) => ModelType::FNet,
            #[cfg(feature = "onnx")]
            Self::ONNX(_) => ModelType::ONNX,
        }
    }

    /// Interface method to forward_t() of the particular models.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Tensor {
        match *self {
            Self::Bart(ref model) => {
                model
                    .forward_t(
                        input_ids.expect("`input_ids` must be provided for BART models"),
                        mask,
                        None,
                        None,
                        None,
                        train,
                    )
                    .decoder_output
            }
            Self::Bert(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::Deberta(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .expect("Error in Deberta forward_t")
                    .logits
            }
            Self::DebertaV2(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .expect("Error in Deberta V2 forward_t")
                    .logits
            }
            Self::DistilBert(ref model) => {
                model
                    .forward_t(input_ids, mask, input_embeds, train)
                    .expect("Error in distilbert forward_t")
                    .logits
            }
            Self::MobileBert(ref model) => {
                model
                    .forward_t(input_ids, None, None, input_embeds, mask, train)
                    .expect("Error in mobilebert forward_t")
                    .logits
            }
            Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::Albert(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::XLNet(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        None,
                        None,
                        None,
                        token_type_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::Reformer(ref model) => {
                model
                    .forward_t(input_ids, None, None, mask, None, train)
                    .expect("Error in Reformer forward pass.")
                    .logits
            }
            Self::Longformer(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        None,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .expect("Error in Longformer forward pass.")
                    .logits
            }
            Self::FNet(ref model) => {
                model
                    .forward_t(input_ids, token_type_ids, position_ids, input_embeds, train)
                    .expect("Error in FNet forward pass.")
                    .logits
            }
            #[cfg(feature = "onnx")]
            Self::ONNX(ref model) => {
                let attention_mask = input_ids.unwrap().ones_like();
                model
                    .forward(
                        input_ids,
                        Some(&attention_mask),
                        token_type_ids,
                        position_ids,
                        input_embeds,
                    )
                    .expect("Error in ONNX forward pass.")
                    .logits
                    .unwrap()
            }
        }
    }
}

/// # SequenceClassificationModel for Classification (e.g. Sentiment Analysis)
///
/// Bundles a tokenizer, a classification model and the label mapping read from the model
/// configuration. Use `predict` for single-label classification and `predict_multilabel` for
/// multi-label classification.
pub struct SequenceClassificationModel {
    /// Tokenizer used to encode the input texts
    tokenizer: TokenizerOption,
    /// Underlying classification model (Torch or ONNX backend)
    sequence_classifier: SequenceClassificationOption,
    /// Maps class indices produced by the model to their label strings
    label_mapping: HashMap<i64, String>,
    /// Device on which the tokenized input tensors are created
    device: Device,
    /// Maximum tokenized input length (`usize::MAX` when the configuration does not define one)
    max_length: usize,
}

impl SequenceClassificationModel {
    /// Build a new `SequenceClassificationModel`
    ///
    /// # Arguments
    ///
    /// * `config` - `SequenceClassificationConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
    ///
    /// let model = SequenceClassificationModel::new(Default::default())?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(
        config: SequenceClassificationConfig,
    ) -> Result<SequenceClassificationModel, RustBertError> {
        let vocab_path = config.vocab_resource.get_local_path()?;
        let merges_path = config
            .merges_resource
            .as_ref()
            .map(|resource| resource.get_local_path())
            .transpose()?;

        let tokenizer = TokenizerOption::from_file(
            config.model_type,
            vocab_path.to_str().unwrap(),
            merges_path.as_deref().map(|path| path.to_str().unwrap()),
            config.lower_case,
            config.strip_accents,
            config.add_prefix_space,
        )?;
        Self::new_with_tokenizer(config, tokenizer)
    }

    /// Build a new `SequenceClassificationModel` with a provided tokenizer.
    ///
    /// # Arguments
    ///
    /// * `config` - `SequenceClassificationConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
    /// * `tokenizer` - `TokenizerOption` tokenizer to use for sequence classification.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
    /// use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
    /// let tokenizer = TokenizerOption::from_file(
    ///  ModelType::Bert,
    ///  "path/to/vocab.txt",
    ///  None,
    ///  false,
    ///  None,
    ///  None,
    /// )?;
    /// let model = SequenceClassificationModel::new_with_tokenizer(Default::default(), tokenizer)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new_with_tokenizer(
        config: SequenceClassificationConfig,
        tokenizer: TokenizerOption,
    ) -> Result<SequenceClassificationModel, RustBertError> {
        let config_path = config.config_resource.get_local_path()?;
        let sequence_classifier = SequenceClassificationOption::new(&config)?;

        let model_config = ConfigOption::from_file(config.model_type, config_path);
        let max_length = model_config
            .get_max_len()
            .map(|v| v as usize)
            .unwrap_or(usize::MAX);
        let label_mapping = model_config.get_label_mapping().clone();
        let device = get_device(config.model_resource, config.device);
        Ok(SequenceClassificationModel {
            tokenizer,
            sequence_classifier,
            label_mapping,
            device,
            max_length,
        })
    }

    /// Get a reference to the model tokenizer.
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }

    /// Get a mutable reference to the model tokenizer.
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }
    /// Classify texts
    ///
    /// # Arguments
    ///
    /// * `input` - `&[&str]` Array of texts to classify.
    ///
    /// # Returns
    ///
    /// * `Vec<Label>` containing labels for input texts
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// # use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
    ///
    /// let sequence_classification_model =  SequenceClassificationModel::new(Default::default())?;
    /// let input = [
    ///  "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
    ///  "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
    ///  "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
    /// ];
    /// let output = sequence_classification_model.predict(&input);
    /// # Ok(())
    /// # }
    /// ```
    pub fn predict<'a, S>(&self, input: S) -> Vec<Label>
    where
        S: AsRef<[&'a str]>,
    {
        let (input_ids, token_type_ids) =
            self.tokenizer
                .tokenize_and_pad(input.as_ref(), self.max_length, self.device);
        let output = no_grad(|| {
            let output = self.sequence_classifier.forward_t(
                Some(&input_ids),
                None,
                Some(&token_type_ids),
                None,
                None,
                false,
            );
            output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
        });
        let label_indices = output.as_ref().argmax(-1, true).squeeze_dim(1);
        let scores = output
            .gather(1, &label_indices.unsqueeze(-1), false)
            .squeeze_dim(1);
        let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
        let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();

        let mut labels: Vec<Label> = vec![];
        for sentence_idx in 0..label_indices.len() {
            let label_string = self
                .label_mapping
                .get(&label_indices[sentence_idx])
                .unwrap()
                .clone();
            let label = Label {
                text: label_string,
                score: scores[sentence_idx],
                id: label_indices[sentence_idx],
                sentence: sentence_idx,
            };
            labels.push(label)
        }
        labels
    }

    /// Multi-label classification of texts
    ///
    /// # Arguments
    ///
    /// * `input` - `&[&str]` Array of texts to classify.
    /// * `threshold` - `f64` threshold above which a label will be considered true by the classifier
    ///
    /// # Returns
    ///
    /// * `Vec<Vec<Label>>` containing a vector of true labels for each input text
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// # use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
    ///
    /// let sequence_classification_model =  SequenceClassificationModel::new(Default::default())?;
    /// let input = [
    ///  "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
    ///  "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
    ///  "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
    /// ];
    /// let output = sequence_classification_model.predict_multilabel(&input, 0.5);
    /// # Ok(())
    /// # }
    /// ```
    pub fn predict_multilabel(
        &self,
        input: &[&str],
        threshold: f64,
    ) -> Result<Vec<Vec<Label>>, RustBertError> {
        let (input_ids, token_type_ids) =
            self.tokenizer
                .tokenize_and_pad(input.as_ref(), self.max_length, self.device);
        let output = no_grad(|| {
            let output = self.sequence_classifier.forward_t(
                Some(&input_ids),
                None,
                Some(&token_type_ids),
                None,
                None,
                false,
            );
            output.sigmoid().detach().to(Device::Cpu)
        });
        let label_indices = output.as_ref().ge(threshold).nonzero();

        let mut labels: Vec<Vec<Label>> = vec![];
        let mut sequence_labels: Vec<Label> = vec![];

        for sentence_idx in 0..label_indices.size()[0] {
            let label_index_tensor = label_indices.get(sentence_idx);
            let sentence_label = label_index_tensor
                .iter::<i64>()
                .unwrap()
                .collect::<Vec<i64>>();
            let (sentence, id) = (sentence_label[0], sentence_label[1]);
            if sentence as usize > labels.len() {
                labels.push(sequence_labels);
                sequence_labels = vec![];
            }
            let score = output.double_value(sentence_label.as_slice());
            let label_string = self.label_mapping.get(&id).unwrap().to_owned();
            let label = Label {
                text: label_string,
                score,
                id,
                sentence: sentence as usize,
            };
            sequence_labels.push(label);
        }
        if !sequence_labels.is_empty() {
            labels.push(sequence_labels);
        }
        Ok(labels)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// Checks at compile time that the result of `SequenceClassificationModel::new` can be boxed
    /// as `dyn Send`. Ignored at runtime because it would load the default remote model resources.
    #[test]
    #[ignore] // no need to run, compilation is enough to verify it is Send
    fn test() {
        let built = SequenceClassificationModel::new(SequenceClassificationConfig::default());
        let _: Box<dyn Send> = Box::new(built);
    }
}
